iT邦幫忙

2018 iT 邦幫忙鐵人賽
DAY 23
0

Introduction

生成對抗網路(Generative Adversarial Network, GAN)機器學習新趨勢

GAN 同時訓練兩個神經網路,生成器(Generator, G)和鑑別器(Discriminator, D),互相對抗激勵而越來越強。

  • 生成器(Generator, G)輸入以常態分布的隨機噪音向量(Z),產生相似於 MNIST 資料集的生成圖像,用以訓練鑑別器。
  • 鑑別器(Discriminator, D)輸入 MNIST 圖像及生成圖像,並試圖判別 MNIST 圖像與生成圖像兩者之間的區別。

D(G(Z)) 為鑑別器(Discriminator, D)的輸出值:

  • 生成器(Generator, G),試圖使鑑別器(Discriminator, D)輸出最小化,判別失敗。
  • 鑑別器(Discriminator, D),試圖讓自己的輸出最大化,判別成功。

訓練過程反覆進行,GAN 兩個神經網路最後會收斂到一個平衡點,得到一個生成模型輸入隨機數字後可產生相似於 MNIST 資料集的生成圖像。

GAN flow

當然,GAN 兩個神經網路之間對抗(adversarial)的,可是是圖像以外的其他資料型式。

Tasks

引用物件。

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os

import cntk as C
import cntk.tests.test_utils
cntk.tests.test_utils.set_device_from_pytest_env()
C.cntk_py.set_fixed_random_seed(1)

%matplotlib inline

設定為快速模式。

isFast = True

1.資料讀取(Data reading):

GAN 是非監督學習模型,所以只讀取資料集中的特徵值(features),而不使用標籤值(labels)。

找尋本地端 MNIST 手寫數字資料集檔案。

data_found = False
for data_dir in [os.path.join("..", "Examples", "Image", "DataSets", "MNIST"),
                 os.path.join("data", "MNIST")]:
    train_file = os.path.join(data_dir, "Train-28x28_cntk_text.txt")
    if os.path.isfile(train_file):
        data_found = True
        break
        
if not data_found:
    raise ValueError("Please generate the data by completing CNTK 103 Part A")
    
print("Data directory is {0}".format(data_dir))

宣告函式:create_reader,讀取資料集。

def create_reader(path, is_training, input_dim, label_dim):
    deserializer = C.io.CTFDeserializer(
        filename = path,
        streams = C.io.StreamDefs(
            labels_unused = C.io.StreamDef(field = 'labels', shape = label_dim, is_sparse = False),
            features = C.io.StreamDef(field = 'features', shape = input_dim, is_sparse = False
            )
        )
    )
    return C.io.MinibatchSource(
        deserializers = deserializer,
        randomize = is_training,
        max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1
    )

宣告函式:noise_sample,隨機產生噪音樣本,介於[-1,1]之間的常態分佈數值。

np.random.seed(123)
def noise_sample(num_samples):
    return np.random.uniform(
        low = -1.0,
        high = 1.0,
        size = [num_samples, g_input_dim]        
    ).astype(np.float32)

2.資料處理(Data preprocessing):

無。

3.建立模型(Model creation):

設定超參數。

g_input_dim = 100
g_hidden_dim = 128
g_output_dim = d_input_dim = 784
d_hidden_dim = 128
d_output_dim = 1

生成器(Generator, G),全連接層(full connected layer) x 2:
輸入資料:100 維度的隨機向量,使用 tanh 作為激活函式(activation function)。
輸出資料:28 x 28 = 784 維度的像素向量。

鑑別器(Discriminator, D),全連接層(full connected layer):
輸入資料:28 x 28 = 784 維度的像素向量,MNIST 資料集及生成器(Generator, G)產生的資料集。
輸出資料:是否為 MNIST 的機率值。

宣告函式:generator,生成器。

def generator(z):
    with C.layers.default_options(init = C.xavier()):
        h1 = C.layers.Dense(g_hidden_dim, activation = C.relu)(z)
        return C.layers.Dense(g_output_dim, activation = C.tanh)(h1)

宣告函式:discriminator,鑑別器。

def discriminator(x):
    with C.layers.default_options(init = C.xavier()):
        h1 = C.layers.Dense(d_hidden_dim, activation = C.relu)(x)
        return C.layers.Dense(d_output_dim, activation = C.sigmoid)(h1)

批次大小(minibatch):1024
學習速率(learning rate):0.0005
訓練回合(iterations):300

minibatch_size = 1024
num_minibatches = 300 if isFast else 40000
lr = 0.00005

4.訓練模型(Learning the model):

宣告函式:build_graph,建立執行流程。

def build_graph(noise_shape, image_shape,
                G_progress_printer, D_progress_printer):
    input_dynamic_axes = [C.Axis.default_batch_axis()]
    Z = C.input_variable(noise_shape, dynamic_axes=input_dynamic_axes)
    X_real = C.input_variable(image_shape, dynamic_axes=input_dynamic_axes)
    X_real_scaled = 2*(X_real / 255.0) - 1.0

    # 建立生成及鑑別模型
    X_fake = generator(Z)
    D_real = discriminator(X_real_scaled)
    D_fake = D_real.clone(
        method = 'share',
        substitutions = {X_real_scaled.output: X_fake.output}
    )

    # 損失函式及最佳化
    G_loss = 1.0 - C.log(D_fake)
    D_loss = -(C.log(D_real) + C.log(1.0 - D_fake))

    G_learner = C.fsadagrad(
        parameters = X_fake.parameters,
        lr = C.learning_parameter_schedule_per_sample(lr),
        momentum = C.momentum_schedule_per_sample(0.9985724484938566)
    )
    D_learner = C.fsadagrad(
        parameters = D_real.parameters,
        lr = C.learning_parameter_schedule_per_sample(lr),
        momentum = C.momentum_schedule_per_sample(0.9985724484938566)
    )

    # 初始化訓練參數
    G_trainer = C.Trainer(
        X_fake,
        (G_loss, None),
        G_learner,
        G_progress_printer
    )
    D_trainer = C.Trainer(
        D_real,
        (D_loss, None),
        D_learner,
        D_progress_printer
    )

    return X_real, X_fake, Z, G_trainer, D_trainer

宣告函式:train,訓練模型。

def train(reader_train):
    k = 2
 
    print_frequency_mbsize = num_minibatches // 50
    pp_G = C.logging.ProgressPrinter(print_frequency_mbsize)
    pp_D = C.logging.ProgressPrinter(print_frequency_mbsize * k)

    X_real, X_fake, Z, G_trainer, D_trainer = \
        build_graph(g_input_dim, d_input_dim, pp_G, pp_D)
    
    input_map = {X_real: reader_train.streams.features}
    for train_step in range(num_minibatches):

        # 訓練鑑別器
        for gen_train_step in range(k):
            Z_data = noise_sample(minibatch_size)
            X_data = reader_train.next_minibatch(minibatch_size, input_map)
            if X_data[X_real].num_samples == Z_data.shape[0]:
                batch_inputs = {X_real: X_data[X_real].data, 
                                Z: Z_data}
                D_trainer.train_minibatch(batch_inputs)

        # 訓練生成器
        Z_data = noise_sample(minibatch_size)
        batch_inputs = {Z: Z_data}
        G_trainer.train_minibatch(batch_inputs)

        G_trainer_loss = G_trainer.previous_minibatch_loss_average

    return Z, X_fake, G_trainer_loss

開始訓練。

reader_train = create_reader(train_file, True, d_input_dim, label_dim=10)

G_input, G_output, G_trainer_loss = train(reader_train)

宣告函式:plot_images,生成圖像。

def plot_images(images, subplot_shape):
    plt.style.use('ggplot')
    fig, axes = plt.subplots(*subplot_shape)
    for image, ax in zip(images, axes.flatten()):
        ax.imshow(image.reshape(28, 28), vmin = 0, vmax = 1.0, cmap = 'gray')
        ax.axis('off')
    plt.show()

使用訓練完成的模型,輸入隨機噪音,並生成圖像。

noise = noise_sample(36)
images = G_output.eval({G_input: noise})
plot_images(images, subplot_shape =[6, 6])

上一篇
藝術風格轉換
下一篇
深度捲積生成對抗網路
系列文
探索 Microsoft CNTK 機器學習工具30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言